import torch.nn as nn
import torch
import math
import torch.nn.functional as F
import numpy as np

# Module-wide default device: CUDA when torch reports it available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def partial_loss(output1, target):
    """Sum the negative log-probabilities of every class except ``target``, averaged over the batch.

    Args:
        output1: Logits; the softmax is taken over dim 1
            (assumed shape ``(batch, num_classes)`` -- TODO confirm with callers).
        target: Integer (int64) class indices, one per sample.

    Returns:
        Scalar tensor ``-sum((1 - one_hot(target)) * log_softmax(output1)) / batch``.
    """
    # Take the class count from the logits rather than hard-coding 10.
    num_classes = output1.size(1)
    keep_mask = 1. - F.one_hot(target, num_classes)
    # log_softmax stays finite where softmax underflows to 0; log(softmax(x))
    # would yield -inf there and 0 * -inf at the masked column becomes NaN.
    log_prob = F.log_softmax(output1, dim=1)
    masked = keep_mask * log_prob
    return -torch.sum(masked) / masked.size(0)
# One-vs-rest (OVR) loss
def ovr_loss(f, labels):
    """One-vs-rest style loss built from softplus terms.

    For each sample the loss is ``softplus(f[label])`` plus the mean over all
    columns of ``softplus(-f[j])`` restricted to ``j != label`` (the label
    column contributes zero to that mean). The per-sample values are summed
    and divided by ``len(labels)``.

    Args:
        f: Logits with classes on dim 1
            (assumed shape ``(batch, num_classes)`` -- TODO confirm with callers).
        labels: Integer (int64) class indices, one per sample.

    Returns:
        Scalar loss tensor.
    """
    # Take the class count from the logits rather than hard-coding 10.
    num_classes = f.size(1)
    y_bar = F.one_hot(labels, num_classes=num_classes)
    y = 1. - y_bar
    # softplus(x) == log(1 + exp(x)) but does not overflow for large x;
    # the naive form gives inf there and inf * 0 in the mask turns into NaN.
    fy_bar = (F.softplus(f) * y_bar).sum(1)
    fy = (F.softplus(-f) * y).mean(1)
    return (fy_bar + fy).sum() / len(labels)


# weighted loss
def non_k_softmax_loss(f, K, labels):
    """Negative log-likelihood of ``labels`` under a softmax of the complemented probabilities.

    The logits are turned into probabilities, complemented (``1 - p``), and
    passed through a second softmax; the log of that distribution is scored
    with ``F.nll_loss`` (mean reduction).

    Args:
        f: Logits; softmax is taken over dim 1.
        K: Not used by this function.
        labels: Class indices; cast to int64 before use.

    Returns:
        Scalar loss tensor.
    """
    complement = 1 - torch.softmax(f, dim=1)
    log_q = torch.log(torch.softmax(complement, dim=1))
    return F.nll_loss(log_q, labels.long())  # Equation(5) in paper

def w_loss(f, K, labels):
    """Return ``non_k_softmax_loss`` plus ``w_loss_p`` for the same inputs.

    Args:
        f: Logits, forwarded unchanged to both terms.
        K: Number of classes, forwarded unchanged to both terms.
        labels: Class indices, forwarded unchanged to both terms.

    Returns:
        Scalar tensor: the sum of the two loss terms.
    """
    # Equation(12) in paper
    return (
        non_k_softmax_loss(f=f, K=K, labels=labels)
        + w_loss_p(f=f, K=K, labels=labels)
    )

def w_loss_p(f, K, labels):
    """Weighted variant of the complemented-softmax negative log-likelihood.

    The complemented probabilities ``1 - softmax(f)`` are normalised per row
    to form weights; each weight multiplies the log of the softmax of the
    complemented probabilities, and the product is scored with ``F.nll_loss``.

    Args:
        f: Logits; softmax is taken over dim 1.
        K: Width the per-row normaliser is expanded to; must match the number
            of columns of ``f`` for the element-wise product to broadcast.
        labels: Class indices; cast to int64 before use.

    Returns:
        Scalar loss tensor.
    """
    complement = 1 - torch.softmax(f, dim=1)
    log_q = torch.softmax(complement, dim=1).log()
    # Reciprocal of each row's total, laid out as (rows, K).
    row_total = complement.sum(dim=1, keepdim=True)
    inv_total = (torch.tensor(1.0) / row_total).expand(-1, K)
    weights = complement * inv_total  # weight
    weighted_log_q = weights * log_q
    return F.nll_loss(weighted_log_q, labels.long())  # Equation(14) in paper
# Non-negative (NN) loss
def non_negative_loss(f, K, labels, ccp, beta): # logits, 10, mincom, ccp, beta=0
    """Per-class corrected loss, floored at ``-beta`` for each class and summed.

    For every class ``k`` that occurs in ``labels`` the mean negative
    log-softmax over the samples labelled ``k`` is computed. It contributes
    ``ccp[k] * mean`` to every class entry and an extra
    ``-(K - 1) * ccp[k] * mean[k]`` to entry ``k``. Each entry of the
    resulting length-``K`` vector is clamped from below at ``-beta`` and the
    entries are summed.

    Args:
        f: Logits of shape ``(batch, K)``.
        K: Number of classes.
        labels: Class index per sample; flattened before use.
        ccp: Per-class weights of length ``K`` (NumPy array, tensor or
            sequence); presumably the class prior of the labels -- verify
            against the caller.
        beta: Non-negative slack; each class entry is floored at ``-beta``.

    Returns:
        Scalar loss tensor on the same device as ``f``.
    """
    # Follow the input's device instead of a module-level global so CPU
    # inputs keep working on CUDA-capable machines.
    target_device = f.device
    ccp = torch.as_tensor(ccp, dtype=torch.float32, device=target_device)
    labels = labels.view(-1)
    neglog = -F.log_softmax(f, dim=1)  # (bs, K)

    loss_vector = torch.zeros(K, device=target_device)
    correction = torch.zeros(K, device=target_device)
    for k in range(K):
        members = labels == k
        if not members.any():
            continue
        mean_neglog_k = neglog[members].mean(dim=0)  # (K,)
        # average of k-th class loss for k-th comp class samples
        correction[k] = -(K - 1) * ccp[k] * mean_neglog_k[k]
        # ccp[k] or ccp, refer to https://github.com/takashiishida/comp/issues/2
        # only k-th in the summation of the second term inside max
        loss_vector = loss_vector + ccp[k] * mean_neglog_k
    loss_vector = loss_vector + correction

    # max(entry, -beta) per class, then sum over classes.
    return torch.clamp(loss_vector, min=-beta).sum()

# forward
def forward_loss(f, K, labels, reduction='mean'): # logits, 10, mincom
    """Forward-corrected negative log-likelihood under a uniform transition matrix.

    The softmax of ``f`` is multiplied by a ``K x K`` matrix whose diagonal is
    zero and whose off-diagonal entries are ``1 / (K - 1)``; the log of the
    result is scored against ``labels`` with ``F.nll_loss``.

    Args:
        f: Logits of shape ``(batch, K)``.
        K: Number of classes.
        labels: Class indices; cast to int64 before use.
        reduction: Passed through to ``F.nll_loss``.

    Returns:
        Loss tensor, reduced according to ``reduction``.
    """
    # uniform assumption; built with the dtype/device of ``f`` so torch.mm
    # does not hit a device or dtype mismatch.
    Q = (1 - torch.eye(K, dtype=f.dtype, device=f.device)) / (K - 1)
    q = torch.mm(F.softmax(f, 1), Q)
    return F.nll_loss(q.log(), labels.long(), reduction=reduction)

# pc
def pc_loss(f, K, labels): # logits, 10, mincom
    """Sigmoid loss on the differences between each logit and the logit at the given label.

    Computes ``sigmoid(-(f - f[label]))`` for every entry, sums the matrix,
    scales by ``(K - 1) / len(labels)``, then subtracts ``K * (K - 1) / 2``
    and adds ``K - 1``.

    Args:
        f: Logits with classes on dim 1; must have ``K`` columns.
        K: Number of classes.
        labels: Class indices; cast to int64 before use.

    Returns:
        Scalar loss tensor.
    """
    label_column = labels.long().view(-1, 1)
    f_at_label = torch.gather(f, 1, label_column).expand(-1, K)
    # multiply -1 for "complementary"
    pairwise = torch.sigmoid(-(f - f_at_label))
    offset_m1 = K * (K - 1) / 2
    offset_m2 = K - 1
    scaled_sum = torch.sum(pairwise) * (K - 1) / len(labels)
    return scaled_sum - offset_m1 + offset_m2


# scl_nl
def scl_nl(logits, target, reduction='mean'):
    """Reduce ``-log(1 - p[target] + 1e-5)`` over the batch, where ``p = softmax(logits)``.

    Args:
        logits: Logits; softmax is taken over dim 1.
        target: Class indices selecting which probability is penalised.
        reduction: Passed through to ``F.nll_loss``.

    Returns:
        Loss tensor, reduced according to ``reduction``.
    """
    probs = torch.softmax(logits, dim=1)
    # 1e-5 keeps the log finite when a probability reaches 1.
    penalty = torch.log(1 - probs + 1e-5).neg()
    # nll_loss returns -penalty[target]; negate to recover +penalty[target].
    return F.nll_loss(penalty, target, reduction=reduction).neg()
# scl_exp
def scl_exp(logits, target, reduction='mean'):
    """Reduce ``exp(p[target])`` over the batch, where ``p = softmax(logits)``.

    Args:
        logits: Logits; softmax is taken over dim 1.
        target: Class indices selecting which probability is penalised.
        reduction: Passed through to ``F.nll_loss``.

    Returns:
        Loss tensor, reduced according to ``reduction``.
    """
    penalty = torch.softmax(logits, dim=1).exp()
    # nll_loss returns -penalty[target]; negate to recover +penalty[target].
    return F.nll_loss(penalty, target, reduction=reduction).neg()